#!/bin/bash

# Launch run_distillation_pt.py through `accelerate launch`.
#
# The arguments are kept in Bash arrays (rather than one long backslash-
# continued command) so that related options can be grouped and commented.
# The resulting argv, including the order of options, is the same as if they
# were written inline.

# Options consumed by `accelerate launch` itself (everything before the script).
launcher_opts=(
  --multi_gpu
  --mixed_precision=bf16
  --num_processes=2
)

# Options forwarded to run_distillation_pt.py.
script_args=(
  # Model checkpoints: --model_name_or_path is presumably the model being
  # trained, alongside the explicitly named teacher -- TODO confirm.
  --model_name_or_path distil-whisper/large-32-2
  --teacher_model_name_or_path openai/whisper-large-v2

  # Training data. Every '+'-separated value below has four fields; they
  # presumably line up positionally (one per dataset), so keep them aligned
  # when editing.
  --train_dataset_config_name all+all+all+l
  --train_dataset_samples 2.9+10.4+14.9+226.6
  --train_dataset_name librispeech_asr+librispeech_asr+librispeech_asr+gigaspeech-l
  --train_split_name train.clean.100+train.clean.360+train.other.500+train

  # Evaluation data. Same '+'-separated convention, three fields each.
  --eval_dataset_name librispeech_asr+librispeech_asr+gigaspeech-l
  --eval_dataset_config_name all+all+l
  --eval_split_name validation.clean+validation.other+validation
  --eval_text_column_name text+text+text

  # Step-based schedule: evaluation/checkpoint cadence, learning rate,
  # logging, and the total step budget.
  --eval_steps 2500
  --save_steps 2500
  --warmup_steps 50
  --learning_rate 0.0001
  --lr_scheduler_type constant_with_warmup
  --logging_steps 25
  --save_total_limit 1
  --max_steps 10000
  --wer_threshold 10

  # Batch sizes and data loading.
  --per_device_train_batch_size 64
  --gradient_accumulation_steps 2
  --per_device_eval_batch_size 64
  --dataloader_num_workers 16

  # Caches, precision, and outputs. Both cache options point at the same
  # directory; the output directory is the current working directory.
  --cache_dir /fsx/sanchit/cache
  --dataset_cache_dir /fsx/sanchit/cache
  --dtype bfloat16
  --output_dir ./
  --wandb_project distil-whisper-training

  # Boolean switches (no values).
  --do_train
  --do_eval
  --gradient_checkpointing
  --overwrite_output_dir
  --predict_with_generate
  --freeze_encoder
  --streaming
)

accelerate launch "${launcher_opts[@]}" run_distillation_pt.py "${script_args[@]}"
